# data/sampler.py
import random
from torch.utils.data import Sampler

class RandomSubsetSampler(Sampler):
    """Yield a random subset of dataset indices, capped at ``num_samples``.

    The subset and its order are drawn once, at construction time. Every
    call to ``__iter__`` replays the same index sequence, so build a new
    sampler if a fresh draw is needed.

    Args:
        data_source: Any object supporting ``len()``; only its length is used.
        num_samples: Maximum number of indices to keep. A value larger than
            ``len(data_source)`` keeps every index, and so does ``None``.
        rng: Optional ``random.Random`` instance to draw from. When omitted,
            the module-level ``random`` state is used (the original behavior).

    Raises:
        ValueError: If ``num_samples`` is negative.
    """

    def __init__(self, data_source, num_samples, *, rng=None):
        # Older torch releases declare ``data_source`` as a required
        # positional parameter of ``Sampler.__init__``; newer ones make it
        # optional. Passing None explicitly is accepted by both.
        super().__init__(None)
        if num_samples is not None and num_samples < 0:
            # A negative slice bound would silently drop indices from the
            # end instead of signalling the bad argument.
            raise ValueError(
                f"num_samples must be non-negative, got {num_samples}"
            )
        shuffle = random.shuffle if rng is None else rng.shuffle
        self.indices = list(range(len(data_source)))
        shuffle(self.indices)
        # Cap the maximum number of samples (slicing past the end is a no-op).
        self.indices = self.indices[:num_samples]

    def __iter__(self):
        """Return an iterator over the indices chosen at construction."""
        return iter(self.indices)

    def __len__(self) -> int:
        """Return the number of indices this sampler yields."""
        return len(self.indices)